import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as F
from    torch.utils.data import TensorDataset, DataLoader
from    torch import optim
import  numpy as np
import time
import copy

from    learner import Learner
from    copy import deepcopy
from functions import *

class Meta(nn.Module):
    """
    Meta Learner

    Wraps a ``Learner`` backbone and trains it episodically: in every
    communication round each user takes one gradient step on its own support
    set, the per-user weights and class prototypes are averaged, and the
    averaged model is scored on the query set with a prototype classifier
    (negative squared Euclidean distance to each prototype).
    """
    def __init__(self, args, config):
        """
        :param args: namespace providing update_lr, meta_lr, n_way, n_spt,
                     n_qry, update_step, update_step_test, imgc, imgsz and round
        :param config: layer configuration forwarded to ``Learner``
        """
        super(Meta, self).__init__()

        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.n_spt = args.n_spt
        self.n_qry = args.n_qry
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test
        self.args = args

        self.net = Learner(config, args.imgc, args.imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)

    def forward(self, step, x_spt, y_spt, x_qry, y_qry, device):
        """
        Run one meta-training step and return the query accuracy.

        :param step: global training step, used only for the lr schedule
        :param x_spt: support images, 5-D with one entry per user on axis 0
        :param y_spt: support labels, one entry per user on axis 0
        :param x_qry: query images, passed to the network as a single batch
        :param y_qry: query labels
        :param device: device on which the prototype tensors are allocated
        :return: query accuracy as a Python float
        """
        # Multiply the meta learning rate by 0.1 every 5000 steps, except at
        # step 0 and step 20000.
        if step != 0 and step != 20000 and step % 5000 == 0:
            for group in self.meta_optim.param_groups:
                group['lr'] = 0.1 * group['lr']

        # The unpack doubles as a check that x_spt is 5-D.
        num_user, _setsz, _c, _h, _w = x_spt.size()
        avg_weight = self.net.parameters()
        global_proto = None

        for _round_idx in range(1, self.args.round + 1):
            weights = []
            prototypes = []
            for i in range(num_user):
                _x_spt = x_spt[i]
                _y_spt = y_spt[i]
                # presumably [N, C, 6, 6] feature maps -- TODO confirm against Learner
                sup_feat = self.net(_x_spt, vars=avg_weight, bn_training=True).squeeze()
                # Pool once and reuse for both the prototypes and the logits.
                pooled = F.avg_pool2d(sup_feat, 6, 1, 0).squeeze()
                prototype = self.make_prototype(pooled, _y_spt, device)
                prob = PN_pred(prototype, pooled)
                Lloss = F.cross_entropy(prob, _y_spt)
                Gloss = self.GlobalLoss(global_proto, sup_feat, _y_spt)
                loss = Lloss + 0.2 * Gloss
                grad = torch.autograd.grad(loss, self.net.parameters(), retain_graph=True)
                # NOTE(review): the step is taken from self.net.parameters()
                # rather than from avg_weight, so every round restarts from
                # the meta-parameters -- confirm this is intended.
                fast_weights = [
                    param - self.update_lr * g
                    for g, param in zip(grad, self.net.parameters())
                ]
                weights.append(fast_weights)
                prototypes.append(prototype)
            avg_weight = average_weights(weights)
            # Stacking keeps the feature width tied to the actual features
            # instead of a hard-coded constant.
            global_proto = average_prototypes(torch.stack(prototypes))

        qry_feat = self.net(x_qry, vars=avg_weight, bn_training=True).squeeze()
        prob = PN_pred(global_proto, F.avg_pool2d(qry_feat, 6, 1, 0).squeeze())
        meta_loss = F.cross_entropy(prob, y_qry)

        self.meta_optim.zero_grad()
        meta_loss.backward()
        self.meta_optim.step()

        with torch.no_grad():
            pred_q2 = prob.argmax(dim=1)
            acc = (pred_q2 == y_qry).float().mean()

        return acc.item()

    def finetunning(self, x_spt, y_spt, x_qry, y_qry, device):
        """
        Adapt a copy of the network on the support sets and return the
        query accuracy; ``self.net`` itself is left untouched.

        :param x_spt: support images, 5-D with one entry per user on axis 0
        :param y_spt: support labels, one entry per user on axis 0
        :param x_qry: query images, passed to the network as a single batch
        :param y_qry: query labels
        :param device: device on which the prototype tensors are allocated
        :return: query accuracy as a Python float
        """
        # in order to not ruin the state of running_mean/variance and bn_weight/bias
        # we finetune on the copied model instead of self.net
        net = deepcopy(self.net)
        # The unpack doubles as a check that x_spt is 5-D.
        num_user, _setsz, _c, _h, _w = x_spt.size()
        global_proto = None

        for _round_idx in range(1, self.args.round + 1):
            weights = []
            prototypes = []
            for i in range(num_user):
                _x_spt = x_spt[i]
                _y_spt = y_spt[i]
                local_net = deepcopy(net)
                optimizer = torch.optim.SGD(local_net.parameters(), lr=self.update_lr)
                sup_feat = net(_x_spt, vars=local_net.parameters(), bn_training=True).squeeze()
                # Pool once and reuse for both the prototypes and the logits.
                pooled = F.avg_pool2d(sup_feat, 6, 1, 0).squeeze()
                prototype = self.make_prototype(pooled, _y_spt, device)
                prob = PN_pred(prototype, pooled)
                Lloss = F.cross_entropy(prob, _y_spt)
                Gloss = self.GlobalLoss(global_proto, sup_feat, _y_spt)
                loss = Lloss + 0.2 * Gloss
                loss.backward()
                optimizer.step()
                weights.append(local_net.state_dict())
                prototypes.append(prototype)
            w_glob = FedAvg(weights)
            # The averaged prototypes only act as fixed targets for the next
            # round: the freshly copied local parameters are not part of their
            # graph, so detaching leaves the local gradients unchanged while
            # avoiding a backward pass through earlier rounds' graphs, whose
            # parameters optimizer.step() has since modified in place.
            global_proto = average_prototypes(torch.stack(prototypes)).detach()
            net.load_state_dict(w_glob)

        with torch.no_grad():
            qry_feat = net(x_qry, vars=net.parameters(), bn_training=True).squeeze()
            prob = PN_pred(global_proto, F.avg_pool2d(qry_feat, 6, 1, 0).squeeze())
            pred_q2 = prob.argmax(dim=1)
            acc = (pred_q2 == y_qry).float().mean()

        del net
        return acc.item()


    def make_prototype(self, feats, y_spt, device):
        """
        Build one prototype per class as the mean feature of that class.

        :param feats: 2-D feature matrix, one row per sample
        :param y_spt: label of every row in ``feats``
        :param device: device on which the result is allocated
        :return: float64 tensor with one row per distinct label; rows are
                 indexed by label value, so labels are expected to be
                 0 .. K-1
        """
        labels = torch.unique(y_spt)
        feat_dim = feats.size(1)
        prototypes = torch.ones((len(labels), feat_dim), dtype=torch.float64, device=device)
        for label in labels:
            prototypes[label] = feats[y_spt == label].mean(dim=0)
        return prototypes

    def GlobalLoss(self, global_proto, sup_feat, _y_spt):
        """
        Cross-entropy between every spatial feature vector and the global
        prototypes, using negative squared distance as the logit.

        :param global_proto: [K, C] prototypes, or None when no global
                             prototypes exist yet
        :param sup_feat: [N, C, H, W] feature maps
        :param _y_spt: N labels, one per feature map
        :return: 0 when ``global_proto`` is None, otherwise a scalar loss
        """
        if global_proto is None:
            return 0
        _n, _c, h, w = sup_feat.shape
        n_class = global_proto.size(0)
        input1 = global_proto.unsqueeze(dim=0)[(...,) + (None,) * 2]  # [1, K, C, 1, 1]
        input2 = sup_feat.unsqueeze(dim=1)                            # [N, 1, C, H, W]
        dist = -(input2 - input1).pow(2).sum(dim=2)                   # [N, K, H, W]
        # One row of K logits per spatial location.
        pred = dist.permute(0, 2, 3, 1).contiguous().view(-1, n_class)
        # Every spatial location of a sample shares that sample's label.
        label = _y_spt.to(device=dist.device, dtype=torch.long)
        label = label.reshape(-1, 1, 1).expand(-1, h, w).reshape(-1)

        return F.cross_entropy(pred, label)

def FedAvg(w):
    """
    Average a list of state dicts key by key.

    :param w: non-empty list of state dicts sharing the same keys
    :return: a new dict holding the element-wise mean of every entry;
             the input dicts are not modified
    """
    num_clients = len(w)
    # Start from a private copy of the first client so the in-place
    # accumulation below never touches the caller's tensors.
    averaged = copy.deepcopy(w[0])
    for key in averaged.keys():
        for client_state in w[1:]:
            averaged[key] += client_state[key]
        averaged[key] = torch.div(averaged[key], num_clients)
    return averaged


def average_weights(weights):
    """
    Average several per-client weight lists position by position.

    :param weights: non-empty sequence of weight lists; every list must be
                    at least as long as the first one
    :return: a new list whose k-th entry is the mean of the k-th entries
             of all clients
    :raises IndexError: if ``weights`` is empty or a later list is shorter
                        than the first one
    """
    num_clients = len(weights)
    # Copy the first list so accumulating does not overwrite the caller's
    # entries (the tensors themselves are never modified in place).
    totals = list(weights[0])
    for client in weights[1:]:
        for pos, running in enumerate(totals):
            totals[pos] = running + client[pos]
    return [total / num_clients for total in totals]

def make_prototype(feats, y_spt, device):
    """
    Build one prototype per class as the mean feature of that class.

    :param feats: 2-D feature matrix, one row per sample
    :param y_spt: label of every row in ``feats``
    :param device: device the result is moved to
    :return: float64 tensor with one row per distinct label; rows are
             indexed by label value
    """
    class_ids = torch.unique(y_spt)
    feat_dim = feats.size(1)
    protos = torch.ones((len(class_ids), feat_dim), dtype=torch.float64)
    protos = protos.to(device)
    for class_id in class_ids:
        # Mean over all samples carrying this label.
        members = feats[y_spt == class_id]
        protos[class_id] = members.mean(dim=0)
    return protos


def PN_pred(prototype, qry_feat):
    """
    Score every query feature against every prototype.

    :param prototype: [K, C] class prototypes
    :param qry_feat: [Q, C] query features
    :return: [Q, K] negative squared Euclidean distances (larger = closer)
    """
    # Broadcast [Q, 1, C] against [K, C] to get all pairwise differences.
    diff = qry_feat[:, None] - prototype
    sq_dist = (diff ** 2).sum(dim=2)
    return -sq_dist


def average_prototypes(prototypes):
    """
    Average per-user prototypes over the first axis.

    :param prototypes: non-empty tensor or sequence indexed by user,
                       e.g. [num_user, n_class, c_]
    :return: the mean of all entries, e.g. [n_class, c_]
    """
    num_users = len(prototypes)
    # Left-to-right accumulation starting from the first user's prototypes.
    total = sum(prototypes[1:], prototypes[0])
    return total / num_users



def main():
    """Script entry point; intentionally does nothing."""
    return None


if __name__ == "__main__":
    main()
